3.3 KNN算法与图像分类

🎯 学习目标

通过手写数字识别器项目,掌握K近邻算法和图像分类的基本概念,包括:

  • 理解KNN算法的原理和工作机制
  • 学会处理图像数据并进行特征提取
  • 掌握图像分类的评估方法
  • 学会使用混淆矩阵分析模型性能
  • 掌握参数调优和数据可视化技术

📋 项目预览

我们将创建一个手写数字识别器,能够识别0-9的手写数字图像。通过学习KNN算法,理解"物以类聚,人以群分"的机器学习思想。

🧠 核心概念详解

1. KNN算法原理

KNN(K-Nearest Neighbors) 的核心思想:

"看看你的邻居是谁,你就可能是谁"

算法步骤

  1. 计算新样本与所有训练样本的距离
  2. 找出距离最近的K个邻居
  3. 根据邻居的类别进行投票
  4. 将得票最多的类别作为预测结果

生活化比喻

  • 你想知道一部电影好不好看
  • 你问K个看过这部电影的朋友
  • 如果大多数朋友说好看,你就认为电影好看

2. 距离度量

欧几里得距离(最常用):

距离 = √[(x₁-x₂)² + (y₁-y₂)² + ...]

曼哈顿距离

距离 = |x₁-x₂| + |y₁-y₂| + ...

在手写数字识别中的应用

  • 每个像素看作一个维度
  • 64个像素 → 64维空间中的点
  • 距离近的数字图像更相似

3. K值的选择

K值的影响

  • K太小:容易受噪声影响,过拟合
  • K太大:可能包含不相关的样本,欠拟合

K值选择原则

  • 通常选择奇数,避免平票
  • 通过交叉验证选择最优K值
  • 经验值:K = √n(n为样本数)

4. 图像数据的处理

手写数字数据集特点

  • 图像尺寸:8×8像素
  • 每个像素值:0-16(灰度值)
  • 总共64个特征(像素)
  • 10个类别(数字0-9)

图像到向量的转换

8×8图像 → 展平为64维向量
[ [1,2,3,...],      → [1,2,3,...,64]
  [4,5,6,...],
  ...            ]

5. 分类评估指标

混淆矩阵(Confusion Matrix)

  • 行:真实类别
  • 列:预测类别
  • 对角线:正确分类
  • 非对角线:分类错误

多分类评估指标

  • 准确率:总体正确率
  • 精确率:预测为正的样本中真正为正的比例
  • 召回率:实际为正的样本中被正确预测的比例
  • F1分数:精确率和召回率的调和平均

🔧 代码实现详解

1. 数据加载和探索

from sklearn.datasets import load_digits

# 加载手写数字数据集
digits = load_digits()

# 查看数据集信息
print("图像数量:", len(digits.images))
print("图像尺寸:", digits.images[0].shape)
print("特征数量:", len(digits.data[0]))
print("类别数量:", len(digits.target_names))

数据探索要点

  • 了解数据的基本结构
  • 查看样本分布是否均衡
  • 可视化一些样本图像

2. 数据可视化

import matplotlib.pyplot as plt

# 显示样本图像
fig, axes = plt.subplots(2, 5, figsize=(10, 4))
for i, ax in enumerate(axes.flat):
    ax.imshow(digits.images[i], cmap='gray')
    ax.set_title(f'数字: {digits.target[i]}')
    ax.axis('off')

可视化作用

  • 直观理解数据
  • 发现数据质量问题
  • 为后续分析提供参考

3. KNN模型训练

from sklearn.neighbors import KNeighborsClassifier

# 创建KNN分类器
knn = KNeighborsClassifier(n_neighbors=3)

# 训练模型
knn.fit(X_train, y_train)

KNN特点

  • 惰性学习:训练时只存储数据,不进行复杂计算
  • 无需训练时间:但预测时需要计算所有距离
  • 对数据规模敏感:大数据集预测较慢

4. 模型预测和评估

from sklearn.metrics import accuracy_score, classification_report

# 预测测试集
y_pred = knn.predict(X_test)

# 计算准确率
accuracy = accuracy_score(y_test, y_pred)

# 详细分类报告
print(classification_report(y_test, y_pred))

5. 混淆矩阵可视化

from sklearn.metrics import confusion_matrix
import seaborn as sns

# 生成混淆矩阵
cm = confusion_matrix(y_test, y_pred)

# 热力图可视化
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues')
plt.xlabel('预测标签')
plt.ylabel('真实标签')

📊 完整代码解析

数据加载和预处理

# 加载数据集
digits = load_digits()

# 数据划分
X_train, X_test, y_train, y_test = train_test_split(
    digits.data, digits.target, test_size=0.2, random_state=42
)
  • 直接使用scikit-learn提供的数据集
  • 保持数据划分的一致性

KNN模型训练

knn = KNeighborsClassifier(n_neighbors=3)
knn.fit(X_train, y_train)
  • 选择K=3作为初始值
  • 使用默认的欧几里得距离

预测概率分析

# 获取预测概率
probabilities = knn.predict_proba([test_sample])[0]

# 找出最可能的3个类别
top3_indices = np.argsort(probabilities)[-3:][::-1]
  • predict_proba返回每个类别的概率
  • 通过排序找出最可能的类别

K值选择实验

k_values = range(1, 11)
accuracies = []

for k in k_values:
    knn_temp = KNeighborsClassifier(n_neighbors=k)
    knn_temp.fit(X_train, y_train)
    accuracy_temp = accuracy_score(y_test, knn_temp.predict(X_test))
    accuracies.append(accuracy_temp)
  • 测试不同K值的效果
  • 帮助选择最优K值

数据降维可视化

from sklearn.decomposition import PCA

# 主成分分析降维
pca = PCA(n_components=2)
X_pca = pca.fit_transform(digits.data)

# 可视化
plt.scatter(X_pca[:, 0], X_pca[:, 1], c=digits.target, cmap='tab10')
  • 将64维数据降到2维
  • 直观展示数据的聚类情况

🎯 学习要点总结

  1. KNN算法原理:理解基于距离的分类思想
  2. 距离度量:掌握欧几里得距离和曼哈顿距离
  3. K值选择:学会通过实验选择最优K值
  4. 图像数据处理:掌握图像到向量的转换方法
  5. 多分类评估:学会使用混淆矩阵和分类报告
  6. 预测概率:理解概率预测和置信度
  7. 数据可视化:掌握多种可视化技术
  8. 参数调优:学会系统性地优化模型参数

💡 练习建议

基础练习

  1. 修改K值:尝试K=1,5,10等不同值,观察准确率变化
  2. 改变距离度量:尝试使用曼哈顿距离或其他距离
  3. 调整数据比例:改变训练集和测试集的比例

进阶练习

  1. 特征标准化:添加数据标准化步骤,观察对KNN的影响
  2. 加权KNN:根据距离给邻居不同的投票权重
  3. 维度灾难:理解高维空间中距离计算的问题

扩展练习

  1. 其他数据集:在MNIST等更大的手写数字数据集上应用
  2. 图像预处理:添加图像增强、去噪等预处理步骤
  3. 实时识别:实现摄像头实时手写数字识别
  4. 自定义分类:训练识别自己手写数字的模型

🔍 常见问题解答

Q: 为什么KNN在大数据集上运行慢?

A: 因为KNN需要计算新样本与所有训练样本的距离,时间复杂度为O(n)。

Q: 如何提高KNN的效率?

A: 可以使用KD树、球树等数据结构加速距离计算,或使用近似最近邻算法。

Q: KNN对特征尺度敏感吗?

A: 非常敏感!不同尺度的特征会影响距离计算,需要进行特征标准化。

Q: KNN适合处理什么类型的数据?

A: 适合数值型数据,类别型数据需要特殊处理(如独热编码)。

🚀 下一步学习

完成KNN项目后,你可以:

  • 学习支持向量机(SVM)处理复杂的分类边界
  • 探索决策树和随机森林处理特征重要性
  • 了解神经网络处理更复杂的图像识别任务
  • 学习聚类算法如K-means进行无监督学习

记住:KNN是理解机器学习"相似性"概念的绝佳起点,它的思想在很多高级算法中都有体现!

« 上一篇 3.2 线性回归与预测 下一篇 » 4.1 深度学习基础